package com.albedo.java.plugins.database.interceptor;

import com.albedo.java.plugins.database.handler.TenantLineHandler;
import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
import com.baomidou.mybatisplus.core.toolkit.*;
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import com.baomidou.mybatisplus.extension.toolkit.PropertyMapper;
import lombok.*;
import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.*;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import java.sql.Connection;
import java.util.*;

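/**
 * Inner interceptor that parses SQL with JSqlParser and adds the tenant column
 * supplied by {@link TenantLineHandler} to select, insert, update and delete statements.
 *
 * @author hubin
 * @since 3.4.0
 */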
@Data
@NoArgsConstructor
@AllArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
@SuppressWarnings({"rawtypes"})
public class TenantLineInnerInterceptor extends JsqlParserSupport implements InnerInterceptor {

	private TenantLineHandler tenantLineHandler;

	/**
	 * Rewrites the SQL of a query through {@code parserSingle} unless the mapped
	 * statement is excluded, either by {@link InterceptorIgnoreHelper} or by the
	 * tenant handler's {@code ignoreMapId} check.
	 */
	@Override
	public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
		final String statementId = ms.getId();
		final boolean excluded = InterceptorIgnoreHelper.willIgnoreTenantLine(statementId)
			|| tenantLineHandler.ignoreMapId(statementId);
		if (!excluded) {
			PluginUtils.MPBoundSql wrapper = PluginUtils.mpBoundSql(boundSql);
			String rewrittenSql = parserSingle(wrapper.sql(), null);
			wrapper.sql(rewrittenSql);
		}
	}

	/**
	 * Rewrites the SQL of INSERT, UPDATE and DELETE statements through
	 * {@code parserMulti} before the statement is prepared. Statements excluded by
	 * the tenant handler's {@code ignoreMapId} check or by
	 * {@link InterceptorIgnoreHelper} are left untouched.
	 */
	@Override
	public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
		final PluginUtils.MPStatementHandler handler = PluginUtils.mpStatementHandler(sh);
		final MappedStatement statement = handler.mappedStatement();
		if (tenantLineHandler.ignoreMapId(statement.getId())) {
			return;
		}

		final SqlCommandType commandType = statement.getSqlCommandType();
		final boolean isWrite = commandType == SqlCommandType.INSERT
			|| commandType == SqlCommandType.UPDATE
			|| commandType == SqlCommandType.DELETE;
		if (!isWrite) {
			return;
		}
		if (InterceptorIgnoreHelper.willIgnoreTenantLine(statement.getId())) {
			return;
		}

		final PluginUtils.MPBoundSql wrapper = handler.mPBoundSql();
		wrapper.sql(parserMulti(wrapper.sql(), null));
	}

	/**
	 * Select handling: processes the main select body first, then every item of
	 * the WITH list, if one is present.
	 */
	@Override
	protected void processSelect(Select select, int index, String sql, Object obj) {
		processSelectBody(select.getSelectBody());
		final List<WithItem> withItems = select.getWithItemsList();
		if (CollectionUtils.isNotEmpty(withItems)) {
			for (WithItem withItem : withItems) {
				processSelectBody(withItem);
			}
		}
	}

	/**
	 * Dispatches a select body to the matching handler.
	 * <ul>
	 * <li>{@link PlainSelect}: handled by {@link #processPlainSelect(PlainSelect)}</li>
	 * <li>{@link WithItem}: the body of its sub select is processed recursively</li>
	 * <li>{@link SetOperationList}: every member select is processed recursively</li>
	 * </ul>
	 * Any other implementation is skipped with a debug log entry instead of
	 * failing with a {@link ClassCastException}.
	 *
	 * @param selectBody the select body to process; {@code null} is ignored
	 */
	protected void processSelectBody(SelectBody selectBody) {
		if (selectBody == null) {
			return;
		}
		if (selectBody instanceof PlainSelect) {
			processPlainSelect((PlainSelect) selectBody);
		} else if (selectBody instanceof WithItem) {
			// a WITH item without a sub select has nothing to rewrite
			SubSelect subSelect = ((WithItem) selectBody).getSubSelect();
			if (subSelect != null) {
				processSelectBody(subSelect.getSelectBody());
			}
		} else if (selectBody instanceof SetOperationList) {
			List<SelectBody> selectBodys = ((SetOperationList) selectBody).getSelects();
			if (CollectionUtils.isNotEmpty(selectBodys)) {
				selectBodys.forEach(this::processSelectBody);
			}
		} else {
			// NOTE(review): presumably only table-less bodies (e.g. VALUES) end up here — confirm
			logger.debug("Unsupported select body [" + selectBody.getClass().getName() + "], tenant line processing skipped");
		}
	}

	/**
	 * Insert handling: appends the tenant column to the column list and the
	 * tenant id to the inserted values (or to the feeding select).
	 * <p>
	 * Nothing is changed when the table is ignored, when the insert has no
	 * explicit column list, or when the handler's {@code ignoreInsert} check
	 * reports that the insert should be skipped.
	 */
	@Override
	protected void processInsert(Insert insert, int index, String sql, Object obj) {
		if (tenantLineHandler.ignoreTable(insert.getTable().getName())) {
			// table is excluded from tenant handling
			return;
		}
		final List<Column> insertColumns = insert.getColumns();
		if (CollectionUtils.isEmpty(insertColumns)) {
			// inserts without an explicit column list are not processed
			return;
		}
		final String tenantColumnName = tenantLineHandler.getTenantIdColumn();
		if (tenantLineHandler.ignoreInsert(insertColumns, tenantColumnName)) {
			// the handler asked to skip this insert (e.g. tenant column already listed)
			return;
		}
		insertColumns.add(new Column(tenantColumnName));

		// fixed gitee pulls/141 duplicate update
		final List<Expression> onDuplicateExpressions = insert.getDuplicateUpdateExpressionList();
		if (CollectionUtils.isNotEmpty(onDuplicateExpressions)) {
			EqualsTo tenantAssignment = new EqualsTo();
			tenantAssignment.setLeftExpression(new StringValue(tenantColumnName));
			tenantAssignment.setRightExpression(tenantLineHandler.getTenantId());
			onDuplicateExpressions.add(tenantAssignment);
		}

		final Select feedingSelect = insert.getSelect();
		if (feedingSelect != null) {
			this.processInsertSelect(feedingSelect.getSelectBody());
			return;
		}

		final ItemsList values = insert.getItemsList();
		if (values == null) {
			throw ExceptionUtils.mpe("Failed to process multiple-table update, please exclude the tableName or statementId");
		}
		// fixed github pull/295
		if (values instanceof MultiExpressionList) {
			// multi-row insert: every row receives its own tenant id expression
			for (ExpressionList row : ((MultiExpressionList) values).getExpressionLists()) {
				row.getExpressions().add(tenantLineHandler.getTenantId());
			}
		} else {
			((ExpressionList) values).getExpressions().add(tenantLineHandler.getTenantId());
		}
	}

	/**
	 * Update handling: adds the tenant condition to the WHERE clause unless the
	 * updated table is ignored by the tenant handler.
	 */
	@Override
	protected void processUpdate(Update update, int index, String sql, Object obj) {
		final Table target = update.getTable();
		if (!tenantLineHandler.ignoreTable(target.getName())) {
			Expression tenantScopedWhere = this.andExpression(target, update.getWhere());
			update.setWhere(tenantScopedWhere);
		}
	}

	/**
	 * Delete handling: adds the tenant condition to the WHERE clause unless the
	 * target table is ignored by the tenant handler.
	 */
	@Override
	protected void processDelete(Delete delete, int index, String sql, Object obj) {
		final Table target = delete.getTable();
		if (!tenantLineHandler.ignoreTable(target.getName())) {
			Expression tenantScopedWhere = this.andExpression(target, delete.getWhere());
			delete.setWhere(tenantScopedWhere);
		}
	}

	/**
	 * Builds the WHERE expression for delete and update statements.
	 *
	 * @param table the table whose (possibly aliased) tenant column is compared
	 * @param where the existing WHERE expression, may be {@code null}
	 * @return the tenant comparison alone when {@code where} is {@code null},
	 * otherwise {@code tenantComparison AND where}; an OR expression is wrapped
	 * in parentheses first
	 */
	protected BinaryExpression andExpression(Table table, Expression where) {
		final EqualsTo tenantMatch = new EqualsTo();
		tenantMatch.setLeftExpression(this.getAliasColumn(table));
		tenantMatch.setRightExpression(tenantLineHandler.getTenantId());
		if (where == null) {
			return tenantMatch;
		}
		// parenthesize OR so the added AND does not bind to only one of its operands
		final Expression existing = (where instanceof OrExpression) ? new Parenthesis(where) : where;
		return new AndExpression(tenantMatch, existing);
	}


	/**
	 * Handles the select part of {@code insert into ... select}.
	 * <p>
	 * Reaching this method means the insert target has the tenant column added,
	 * so the feeding select must deliver the tenant column as well.
	 * <ul>
	 * <li>{@link PlainSelect} reading from a table: the tenant condition is added
	 * and the tenant column is appended to the select items</li>
	 * <li>{@link PlainSelect} reading from a sub select: the tenant column is
	 * appended to the outer select items and the sub select is processed
	 * recursively</li>
	 * <li>{@link SetOperationList} (e.g. UNION): every member select is processed
	 * recursively so that all branches deliver the tenant column</li>
	 * </ul>
	 *
	 * @param selectBody SelectBody
	 * @throws RuntimeException created by {@code ExceptionUtils.mpe} when the
	 * select body is of any other type (previously a {@link ClassCastException}
	 * or {@link NullPointerException})
	 */
	protected void processInsertSelect(SelectBody selectBody) {
		if (selectBody instanceof PlainSelect) {
			PlainSelect plainSelect = (PlainSelect) selectBody;
			FromItem fromItem = plainSelect.getFromItem();
			if (fromItem instanceof Table) {
				// fixed gitee pulls/141 duplicate update
				processPlainSelect(plainSelect);
				appendSelectItem(plainSelect.getSelectItems());
			} else if (fromItem instanceof SubSelect) {
				SubSelect subSelect = (SubSelect) fromItem;
				appendSelectItem(plainSelect.getSelectItems());
				processInsertSelect(subSelect.getSelectBody());
			}
		} else if (selectBody instanceof SetOperationList) {
			List<SelectBody> branches = ((SetOperationList) selectBody).getSelects();
			if (CollectionUtils.isNotEmpty(branches)) {
				branches.forEach(this::processInsertSelect);
			}
		} else {
			// fail with a descriptive message rather than an opaque cast error
			String type = selectBody == null ? "null" : selectBody.getClass().getName();
			throw ExceptionUtils.mpe("Failed to process insert select, unsupported select body [%s], please exclude the tableName or statementId", type);
		}
	}

	/**
	 * Appends the tenant column to a select item list.
	 * <p>
	 * Nothing is appended when the list is empty, or when its only item is a
	 * wildcard ({@code *} or {@code table.*}).
	 *
	 * @param selectItems SelectItem
	 */
	protected void appendSelectItem(List<SelectItem> selectItems) {
		if (CollectionUtils.isEmpty(selectItems)) {
			return;
		}
		boolean wildcardOnly = false;
		if (selectItems.size() == 1) {
			final SelectItem only = selectItems.get(0);
			wildcardOnly = only instanceof AllColumns || only instanceof AllTableColumns;
		}
		if (!wildcardOnly) {
			Column tenantColumn = new Column(tenantLineHandler.getTenantIdColumn());
			selectItems.add(new SelectExpressionItem(tenantColumn));
		}
	}

	/**
	 * Handles a PlainSelect: sub selects inside the WHERE clause, the FROM item,
	 * the select items and finally the joins, in that order.
	 */
	protected void processPlainSelect(PlainSelect plainSelect) {
		final FromItem source = plainSelect.getFromItem();
		final Expression originalWhere = plainSelect.getWhere();
		processWhereSubSelect(originalWhere);

		if (!(source instanceof Table)) {
			processFromItem(source);
		} else {
			final Table sourceTable = (Table) source;
			final boolean excluded = tenantLineHandler.ignoreTable(sourceTable.getName());
			if (!excluded) {
				//#1186 github
				plainSelect.setWhere(builderExpression(originalWhere, sourceTable));
			}
		}

		//#3087 github
		final List<SelectItem> items = plainSelect.getSelectItems();
		if (CollectionUtils.isNotEmpty(items)) {
			for (SelectItem item : items) {
				processSelectItem(item);
			}
		}

		final List<Join> joinList = plainSelect.getJoins();
		if (CollectionUtils.isNotEmpty(joinList)) {
			processJoins(joinList);
		}
	}

	/**
	 * Handles sub selects nested inside a WHERE condition.
	 * <p>
	 * Supported forms:
	 * <ol>
	 * <li>{@code in}</li>
	 * <li>{@code =}</li>
	 * <li>{@code >}</li>
	 * <li>{@code <}</li>
	 * <li>{@code >=}</li>
	 * <li>{@code <=}</li>
	 * <li>{@code <>}</li>
	 * <li>{@code EXISTS}</li>
	 * <li>{@code NOT EXISTS}</li>
	 * </ol>
	 * Preconditions:
	 * <ol>
	 * <li>the sub select must be enclosed in parentheses</li>
	 * <li>the sub select normally sits on the right-hand side of the comparison operator</li>
	 * </ol>
	 *
	 * @param where the WHERE condition, may be {@code null}
	 */
	protected void processWhereSubSelect(Expression where) {
		if (where == null) {
			return;
		}
		if (where instanceof FromItem) {
			processFromItem((FromItem) where);
			return;
		}
		// only descend when the rendered expression contains SELECT after its first character
		if (where.toString().indexOf("SELECT") <= 0) {
			return;
		}

		if (where instanceof BinaryExpression) {
			// comparison operators, and, or, etc.: inspect both operands
			final BinaryExpression binary = (BinaryExpression) where;
			processWhereSubSelect(binary.getLeftExpression());
			processWhereSubSelect(binary.getRightExpression());
		} else if (where instanceof InExpression) {
			// in: only a sub select on the right-hand side is processed
			final ItemsList rightItems = ((InExpression) where).getRightItemsList();
			if (rightItems instanceof SubSelect) {
				processSelectBody(((SubSelect) rightItems).getSelectBody());
			}
		} else if (where instanceof ExistsExpression) {
			// exists
			processWhereSubSelect(((ExistsExpression) where).getRightExpression());
		} else if (where instanceof NotExpression) {
			// not exists
			processWhereSubSelect(((NotExpression) where).getExpression());
		} else if (where instanceof Parenthesis) {
			processWhereSubSelect(((Parenthesis) where).getExpression());
		}
	}

	/**
	 * Handles a single select item: a sub select is processed as a select body,
	 * a function call is handed to {@link #processFunction(Function)}; every
	 * other kind of item is left untouched.
	 *
	 * @param selectItem the select item to inspect
	 */
	protected void processSelectItem(SelectItem selectItem) {
		if (!(selectItem instanceof SelectExpressionItem)) {
			return;
		}
		final Expression expression = ((SelectExpressionItem) selectItem).getExpression();
		if (expression instanceof SubSelect) {
			processSelectBody(((SubSelect) expression).getSelectBody());
		} else if (expression instanceof Function) {
			processFunction((Function) expression);
		}
	}

	/**
	 * Handles a function call in the select list.
	 * <p>Supported: 1. select fun(args..) 2. select fun1(fun2(args..),args..)</p>
	 * <p>fixed gitee pulls/141</p>
	 *
	 * @param function the function whose parameters are searched for sub selects
	 *                 and nested function calls
	 */
	protected void processFunction(Function function) {
		final ExpressionList arguments = function.getParameters();
		if (arguments == null) {
			return;
		}
		for (Expression argument : arguments.getExpressions()) {
			if (argument instanceof SubSelect) {
				processSelectBody(((SubSelect) argument).getSelectBody());
			} else if (argument instanceof Function) {
				processFunction((Function) argument);
			}
		}
	}

	/**
	 * Handles FROM items that are not plain tables: nested joins, sub selects,
	 * values lists and lateral sub selects.
	 *
	 * @param fromItem the FROM item to process
	 */
	protected void processFromItem(FromItem fromItem) {
		if (fromItem instanceof SubJoin) {
			final SubJoin nested = (SubJoin) fromItem;
			// the join list is handled before the left item
			final List<Join> nestedJoins = nested.getJoinList();
			if (nestedJoins != null) {
				processJoins(nestedJoins);
			}
			final FromItem left = nested.getLeft();
			if (left != null) {
				processFromItem(left);
			}
			return;
		}
		if (fromItem instanceof SubSelect) {
			final SelectBody body = ((SubSelect) fromItem).getSelectBody();
			if (body != null) {
				processSelectBody(body);
			}
			return;
		}
		if (fromItem instanceof ValuesList) {
			logger.debug("Perform a subquery, if you do not give us feedback");
			return;
		}
		if (fromItem instanceof LateralSubSelect) {
			final SubSelect inner = ((LateralSubSelect) fromItem).getSubSelect();
			if (inner != null && inner.getSelectBody() != null) {
				processSelectBody(inner.getSelectBody());
			}
		}
	}

	/**
	 * Handles a list of joins.
	 * <p>
	 * A join on a table with exactly one ON expression is delegated to
	 * {@link #processJoin(Join)}. Otherwise the table is recorded on a stack, and
	 * when a join carries several ON expressions each of them is paired with a
	 * table taken from that stack. Joins whose right item is not a table are
	 * handed to {@link #processFromItem(FromItem)}.
	 *
	 * @param joins the joins to process
	 */
	private void processJoins(List<Join> joins) {
		// For joins whose ON expressions are written at the end, the tables of the preceding joins must be remembered
		Deque<Table> tables = new LinkedList<>();
		for (Join join : joins) {
			// handle the ON expressions
			FromItem fromItem = join.getRightItem();
			if (fromItem instanceof Table) {
				Table fromTable = (Table) fromItem;
				// the list of ON expressions attached to this join
				Collection<Expression> originOnExpressions = join.getOnExpressions();
				// a regular join has exactly one ON expression: handle it right away
				if (originOnExpressions.size() == 1) {
					processJoin(join);
					continue;
				}
				// whether the current table is ignored
				boolean needIgnore = tenantLineHandler.ignoreTable(fromTable.getName());
				// push the table; an ignored table is pushed as null so it is skipped later
				tables.push(needIgnore ? null : fromTable);
				// several trailing ON expressions are handled together
				if (originOnExpressions.size() > 1) {
					Collection<Expression> onExpressions = new LinkedList<>();
					for (Expression originOnExpression : originOnExpressions) {
						// push/poll both work on the head, so the most recently pushed table is taken first
						Table currentTable = tables.poll();
						if (currentTable == null) {
							// ignored table (or nothing left on the stack): keep the expression unchanged
							onExpressions.add(originOnExpression);
						} else {
							onExpressions.add(builderExpression(originOnExpression, currentTable));
						}
					}
					join.setOnExpressions(onExpressions);
				}
			} else {
				// handle the sub expression joined on the right side
				processFromItem(fromItem);
			}
		}
	}

	/**
	 * Handles a single join on a table: its first ON expression is replaced by
	 * that expression combined with the tenant condition. Joins on ignored
	 * tables or on non-table items are left untouched.
	 *
	 * @param join the join to process
	 */
	protected void processJoin(Join join) {
		final FromItem rightItem = join.getRightItem();
		if (!(rightItem instanceof Table)) {
			return;
		}
		final Table joinedTable = (Table) rightItem;
		if (tenantLineHandler.ignoreTable(joinedTable.getName())) {
			// table is excluded from tenant handling
			return;
		}
		// callers only get here with exactly one ON expression
		final Collection<Expression> currentOn = join.getOnExpressions();
		final List<Expression> rewrittenOn = new LinkedList<>();
		final Expression firstOn = currentOn.iterator().next();
		rewrittenOn.add(builderExpression(firstOn, joinedTable));
		join.setOnExpressions(rewrittenOn);
	}

	/**
	 * Combines an existing condition with the tenant comparison for a table.
	 *
	 * @param currentExpression the existing condition, may be {@code null}
	 * @param table             the table whose (possibly aliased) tenant column is compared
	 * @return the tenant comparison alone when there is no existing condition,
	 * otherwise {@code currentExpression AND tenantComparison}; an OR expression
	 * is wrapped in parentheses first
	 */
	protected Expression builderExpression(Expression currentExpression, Table table) {
		final EqualsTo tenantMatch = new EqualsTo();
		tenantMatch.setLeftExpression(this.getAliasColumn(table));
		tenantMatch.setRightExpression(tenantLineHandler.getTenantId());
		if (currentExpression == null) {
			return tenantMatch;
		}
		// parenthesize OR so the added AND does not bind to only one of its operands
		final Expression existing = (currentExpression instanceof OrExpression)
			? new Parenthesis(currentExpression)
			: currentExpression;
		return new AndExpression(existing, tenantMatch);
	}

	/**
	 * Builds the tenant column for a table, prefixed with the table alias when
	 * one is set.
	 * <p>tenantId or tableAlias.tenantId</p>
	 *
	 * @param table the table object
	 * @return the tenant column
	 */
	protected Column getAliasColumn(Table table) {
		final Alias alias = table.getAlias();
		final String prefix = (alias == null) ? "" : alias.getName() + StringPool.DOT;
		return new Column(prefix + tenantLineHandler.getTenantIdColumn());
	}

	/**
	 * Configures the interceptor from properties: via
	 * {@code PropertyMapper.whenNotBlank}, the {@code tenantLineHandler} property
	 * is passed through {@code ClassUtils.newInstance} and the result is handed
	 * to {@code setTenantLineHandler}.
	 */
	@Override
	public void setProperties(Properties properties) {
		final PropertyMapper mapper = PropertyMapper.newInstance(properties);
		mapper.whenNotBlank("tenantLineHandler", ClassUtils::newInstance, this::setTenantLineHandler);
	}
}



